import numpy as np
import h5py


def load_file(train_path="datasets/train_dataset.hdf5",
              test_path="datasets/test_dataset.hdf5"):
    """Load the training and test sets from two HDF5 files.

    Reads the datasets ``train_set_x`` / ``train_set_y`` from *train_path*
    and ``test_set_x`` / ``test_set_y`` / ``list_class`` from *test_path*.
    Each is copied fully into memory as a NumPy array, so the files can be
    closed before returning.

    Image features are returned as stored, not flattened. Callers that
    need a ``(features, m)`` matrix must reshape them themselves.
    NOTE(review): earlier comments suggested the images are
    ``(m, 64, 64, 3)``. Confirm this against the data files.

    Args:
        train_path: Path to the training HDF5 file.
        test_path: Path to the test HDF5 file.

    Returns:
        Tuple ``(train_set_x_orig, train_set_y_orig, test_set_x_orig,
        test_set_y_orig, classes)``. The two label arrays are reshaped
        to a single row, shape ``(1, m)``.
    """
    # Use context managers so the file handles are always released.
    # Slicing with [:] and wrapping in np.array copies the data out first,
    # so the arrays stay valid after the files close.
    with h5py.File(train_path, "r") as train_dataset:
        train_set_x_orig = np.array(train_dataset["train_set_x"][:])
        train_set_y_orig = np.array(train_dataset["train_set_y"][:])

    with h5py.File(test_path, "r") as test_dataset:
        test_set_x_orig = np.array(test_dataset["test_set_x"][:])
        test_set_y_orig = np.array(test_dataset["test_set_y"][:])
        classes = np.array(test_dataset["list_class"][:])

    # Force the labels into a row vector (1, m). Using -1 gives the same
    # result as (1, shape[0]) for 1-D labels and also accepts labels
    # already stored as 2-D.
    train_set_y_orig = train_set_y_orig.reshape(1, -1)
    test_set_y_orig = test_set_y_orig.reshape(1, -1)
    return train_set_x_orig, train_set_y_orig, test_set_x_orig, test_set_y_orig, classes